【BQML応用記事】BigQuery MLで作った機械学習のモデルでオンライン予測を実施する

【BQML応用記事】BigQuery MLで作った機械学習のモデルでオンライン予測を実施する

Clock Icon2020.12.25

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

こんにちは、Mr.Moです。

このエントリは『クラスメソッド BigQuery Advent Calendar 2020』25本目のエントリです。本アドベントカレンダーもこのエントリで最後とのことで、これまで弊社のメンバーで書いてきた記事もご覧いただきながら最後までお楽しみいただければと思います。

当エントリではBigQuery ML(以下、BQML)のさらなる実用的な使い方である、オンライン予測ができるまでを見ていこうと思います。

BQMLとは

BQMLの概要については下記の記事にまとめてますのでご覧いただければと思います。

オンライン予測ができるまでをやっていく

今回はkaggleのタイタニックのデータを使って生存者の予測をするというお題で進めていこうと思います。

データの内容は下記のようになっています。このデータからタイタニック号での生存を予測(survivalの0,1を予測)しようというものです。

変数 定義 key
survival 生存 1 = 生存, 0 = 生存しない
pclass チケットのクラス 1 = 1st, 2 = 2nd, 3 = 3rd
sex 性別
age 年齢
sibsp タイタニック号に乗っている兄弟/配偶者の数
parch タイタニック号に乗っている親/子の数
ticket チケット番号
fare 旅客運賃
cabin キャビン番号
embarked 乗船した港 C = Cherbourg, Q = Queenstown, S = Southampton

先にトレーニングデータ(train.csv)をデータセットに追加しておきます。

image.png

モデルのトレーニング

先ほど追加したデータを使ってトレーニングを行います。BQMLならSQLで簡単にトレーニングも実施できますね。 使うモデルはXGBoostでいこうと思います。(ちなみに最初はAutoML Tablesを使う予定でしたが、オンライン予測はまだ対応していませんでした

CREATE OR REPLACE MODEL Titanic.xgboost_model
OPTIONS(
MODEL_TYPE='boosted_tree_classifier', INPUT_LABEL_COLS=["Survived"]
) AS 
SELECT
    * EXCEPT(PassengerId, Name, Ticket, Fare, Cabin)
  FROM
    `Titanic.train`

image.png

モデルのエクスポート

モデルのトレーニングが完了したら、Cloud Storage バケットにモデルをエクスポートします。 Cloud Shell上で下記のコマンドを実行します。

$ gsutil mb gs://titanic-2020
$ bq extract --destination_format ML_XGBOOST_BOOSTER -m Titanic.xgboost_model gs://titanic-2020/xgboost_model

モデルのデプロイ

そしてモデルをデプロイしていきます。デプロイ以降はAI Platformサービスを使います。

モデルリソースを作成し(ここでregionの選択を求められますが[1] globalを選択しました)

$ MODEL_NAME="TITANIC_XGBOOST_MODEL"
$ gcloud ai-platform models create $MODEL_NAME

モデル バージョンを作成して

$ MODEL_DIR="gs://titanic-2020/xgboost_model"
$ VERSION_NAME="v1"
$ gcloud beta ai-platform versions create $VERSION_NAME --model=$MODEL_NAME --origin=$MODEL_DIR --package-uris=${MODEL_DIR}/xgboost_predictor-0.1.tar.gz--prediction-class=predictor.Predictor --runtime-version=1.15 --machine-type="mls1-c1-m2"

エラーが出なければOKです。 さっそくコマンドで予測を実行してみましょう。

$ vi instances.json
$ INPUT_DATA_FILE="instances.json"

$ cat instances.json
{"Pclass":2, "Sex":"male", "Age":33, "SibSp":1, "Parch":2, "Embarked":"Q"}

$ gcloud ai-platform predict --model $MODEL_NAME --version $VERSION_NAME --json-instances $INPUT_DATA_FILE

image.png

無事、予測結果が返ってきましたね!

サービスアカウントの作成

もう少し色んなところで使えるようにしていきたいと思います。curlやプログラムでも作った予測モデルを使えるようにサービスアカウントを作成して実行権限を付与していきます。 そしてjson形式のkey fileを作成・ダウンロードします。(ai-work-275303-xxx.jsonのファイル)

image.png

curlで予測実行

下準備が整いましたのでまずは軽くcurlから予測を実行してみます。

$ export GOOGLE_APPLICATION_CREDENTIALS="/home/takashi1_kawamoto/ai-work-275303-b0d2af79eeab.json"
$ export YOUR_PROJECT_ID="ai-work-275303"
$ export YOUR_MODEL_NAME="TITANIC_XGBOOST_MODEL"
$ curl -H "Authorization: Bearer $(gcloud auth application-default print-access-token)" \
    -H "Content-Type: application/json" \
    -X POST \
    -d '{"instances":[{"Pclass":2, "Sex":"male", "Age":33, "SibSp":1, "Parch":2, "Embarked":"Q"}]}' \
    https://ml.googleapis.com/v1/projects/${YOUR_PROJECT_ID}/models/${YOUR_MODEL_NAME}:predict

image.png

アプリに組み込んで予測実行

ここまででも予測は実行できてますが、やはりアプリっぽいものから実行する方が楽しいと思うので対応していきます。LINEのBot(Messaging API)なら簡単に要件を満たせるのでこの方向で進めていきましょう。

ここから先はアプリ化に必要な流れや情報をかなり簡単に記載していきます。(雰囲気だけ感じていただければ、もしこの情報だけで手を動かせそうでしたらぜひ!)

下記にざっくり必要なものを記載します。

フォルダ構成は下記です。

.
├── ai-work-275303-b0d2af79eeab.json
├── app.py
└── Dockerfile

各ファイルの中身は下記です。

  • app.py
import json
import os

from flask import Flask, abort, request
from google.cloud import datastore
from googleapiclient import discovery
from oauth2client.client import GoogleCredentials

from linebot import (
    LineBotApi, WebhookHandler
)
from linebot.exceptions import (
    InvalidSignatureError
)
from linebot.models import (
    MessageEvent, TextMessage, TextSendMessage, StickerSendMessage, PostbackEvent, PostbackAction, QuickReply, QuickReplyButton
)

client = datastore.Client()

app = Flask(__name__)

line_bot_api = LineBotApi(os.getenv('LINE_CHANNEL_ACCESS_TOKEN', None))
handler = WebhookHandler(os.getenv('LINE_CHANNEL_SECRET', None))

sa_keyfile = os.getcwd() + '/ai-work-275303-63c1bb1f0331.json'
PROJECT_ID = os.getenv('PROJECT_ID', None)
MODEL_NAME = os.getenv('MODEL_NAME', None)
URL = 'https://ml.googleapis.com/v1/projects/{}/models/{}:predict'

@app.route('/callback', methods=['POST'])
def callback():
    # get X-Line-Signature header value
    signature = request.headers['X-Line-Signature']

    # get request body as text
    body = request.get_data(as_text=True)
    app.logger.info('Request body: ' + body)

    # handle webhook body
    try:
        handler.handle(body, signature)
    except InvalidSignatureError:
        print('Invalid signature. Please check your channel access token/channel secret.')
        abort(400)

    return 'OK'


@handler.add(MessageEvent, message=TextMessage)
def handle_message(event):
    key = client.key('Titanic', event.source.user_id)
    entity = datastore.Entity(key=key)
    result = client.get(key)

    if event.message.text == '予測':
        entity.update({
            'question': '0',
        })
        client.put(entity)
        line_bot_api.reply_message(
            event.reply_token,
            TextSendMessage(
                text='チケットのクラスはどれですか?',
                quick_reply=QuickReply(
                    items=[
                        QuickReplyButton(
                            action=PostbackAction(label='1st', data='1', display_text='1st')
                        ),
                        QuickReplyButton(
                            action=PostbackAction(label='2nd', data='2', display_text='2nd')
                        ),
                        QuickReplyButton(
                            action=PostbackAction(label='3rd', data='3', display_text='3rd')
                        )
                    ])))  

    elif result.get('question') == '2':
        entity.update({
            'question': '3',
            'Pclass': result.get('Pclass'),
            'Sex': result.get('Sex'),
            'Age': event.message.text,
        })
        client.put(entity)
        line_bot_api.reply_message(
            event.reply_token,
            TextSendMessage(text='乗船している兄弟・配偶者の人数を教えてください。')) 

    elif result.get('question') == '3':
        entity.update({
            'question': '4',
            'Pclass': result.get('Pclass'),
            'Sex': result.get('Sex'),
            'Age': result.get('Age'),
            'SibSp': event.message.text
        })
        client.put(entity)
        line_bot_api.reply_message(
            event.reply_token,
            TextSendMessage(text='乗船している両親・子供の人数を教えてください。'))   

    elif result.get('question') == '4':
        entity.update({
            'question': '5',
            'Pclass': result.get('Pclass'),
            'Sex': result.get('Sex'),
            'Age': result.get('Age'),
            'SibSp': result.get('SibSp'),
            'Parch': event.message.text
        })
        client.put(entity)
        line_bot_api.reply_message(
            event.reply_token,
            TextSendMessage(
                text='乗船した港はどれですか?',
                quick_reply=QuickReply(
                    items=[
                        QuickReplyButton(
                            action=PostbackAction(label='Cherbourg', data='C', display_text='Cherbourg')
                        ),
                        QuickReplyButton(
                            action=PostbackAction(label='Queenstown', data='Q', display_text='Queenstown')
                        ),
                        QuickReplyButton(
                            action=PostbackAction(label='Southampton', data='S', display_text='Southampton')
                        )
                    ])))  

    else:
        line_bot_api.reply_message(
            event.reply_token,
            TextSendMessage(
                text='『予測』とメッセージを送ってみてください。'))


@handler.add(PostbackEvent)
def handle_postback(event):
    key = client.key('Titanic', event.source.user_id)
    entity = datastore.Entity(key=key)
    result = client.get(key)
    app.logger.info(result)

    if result.get('question') == '0':
        entity.update({
            'question': '1',
            'Pclass': event.postback.data,
        })
        client.put(entity)
        line_bot_api.reply_message(
            event.reply_token,
            TextSendMessage(
                text='性別はどちらですか?',
                quick_reply=QuickReply(
                    items=[
                        QuickReplyButton(
                            action=PostbackAction(label='男性', data='male', display_text='男性')
                        ),
                        QuickReplyButton(
                            action=PostbackAction(label='女性', data='female', display_text='女性')
                        )
                    ])))

    elif result.get('question') == '1':
        entity.update({
            'question': '2',
            'Pclass': result.get('Pclass'),
            'Sex': event.postback.data,

        })
        client.put(entity)
        line_bot_api.reply_message(
            event.reply_token,
            TextSendMessage(text='年齢はいくつですか?')) 

    elif result.get('question') == '5':
        entity.update({
            'question': '5',
            'Pclass': result.get('Pclass'),
            'Sex': result.get('Sex'),
            'Age': result.get('Age'),
            'SibSp': result.get('SibSp'),
            'Parch': result.get('Parch'),
            'Embarked': event.postback.data,
        })
        client.put(entity)

        inputs_for_prediction = [
            {'Pclass':result.get('Pclass'), 'Sex':result.get('Sex'), 'Age':result.get('Age'), 'SibSp':result.get('SibSp'), 'Parch':result.get('Parch'), 'Embarked':event.postback.data}
        ]

        credentials = GoogleCredentials.from_stream(sa_keyfile)
        service = discovery.build('ml', 'v1', credentials=credentials)

        name = 'projects/{}/models/{}'.format(PROJECT_ID, MODEL_NAME)

        response = service.projects().predict(
            name=name,
            body={'instances': inputs_for_prediction}
        ).execute()

        predicted_survived = response['predictions'][0]['predicted_Survived']

        if predicted_survived == '1':
            line_bot_api.reply_message(
                event.reply_token,
                [TextSendMessage(
                    text='安心してください。あなたは無事に帰ってこれるでしょう。'),
                 StickerSendMessage(
                    package_id='11537',
                    sticker_id='52002735')]) 
        elif predicted_survived == '0':
            line_bot_api.reply_message(
                event.reply_token,
                [TextSendMessage(
                    text='あなたには困難な運命が待ち受けている...かもしれません...'),
                 StickerSendMessage(
                    package_id='11537',
                    sticker_id='52002755')]) 

    else:
        line_bot_api.reply_message(
            event.reply_token,
            TextSendMessage(text='『予測』とメッセージを送ってみてください。'))


if __name__ == '__main__':
    app.run()

  • Dockerfile
# Use the official Python image.
# https://hub.docker.com/_/python
FROM python:3.7

# Copy local code to the container image.
ENV APP_HOME /app
WORKDIR $APP_HOME
COPY . .

# Install production dependencies.
RUN pip install Flask gunicorn line-bot-sdk google-cloud-datastore oauth2client google-api-python-client

# Run the web service on container startup. Here we use the gunicorn
# webserver, with one worker process and 8 threads.
# For environments with multiple CPU cores, increase the number of workers
# to be equal to the cores available.
CMD exec gunicorn --bind :$PORT --workers 1 --threads 8 app:app

GitHub Codespacesの拡張機能にはCloud runのデプロイなどを操作できるCloud Codeがあるのでそちらを使ってデプロイを実施します。

image.png

アプリへの組み込み完了しました。最近のテクノロジーを駆使するとアプリ開発はすごく楽になりますね!それでは本題のアプリから予測を実行していきましょう。

下記は私っぽい情報を入れたところです...ちょっとダメそうな予測結果が返ってきました...

bqml-online-predict.gif

ちなみに妻は...

bqml-online-predict2.gif

おお!良かった!無事のようですε-(´∀`*)ホッ

まとめ

BQMLで作成したモデルをオンライン予測するまでの一通りの流れを見ていただきました。いちおう簡単にですがアプリっぽいものへの組み込みまでやっていきました。BQMLも機械学習の民主化を強く感じるものでしたが、昨今の開発の多くが非常にハードルが下がっていて頭にイメージしたものがクイックに実現できる素晴らしい世の中になっているなぁと感動もしておりました。ぜひ皆さまもBigQuery、機械学習を含む自分のやりたいことをこの冬休みに考えたり実行してみたりしてはいかがでしょうか?その際に本アドベントカレンダーが皆さまの助けになれれば幸いです。

参考

Share this article

facebook logohatena logotwitter logo

© Classmethod, Inc. All rights reserved.